{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NZKUpjEb6wTZ",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "!pip install nibabel\n",
        "!pip install torchio\n",
        "!pip install onnx\n",
        "!pip install onnxruntime"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "B6KLtDsD7BOU",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torchvision\n",
        "from torchvision import datasets,transforms, models\n",
        "import torchvision.transforms.functional as TF\n",
        "import nibabel as nib\n",
        "from pathlib import Path\n",
        "import onnxruntime\n",
        "from sklearn.metrics import precision_recall_fscore_support\n",
        "from scipy import stats\n",
        "import onnx"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "CiBuxhgm7NMB",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "# to match fMRI ICA 100 components\n",
        "batch_size = 100"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "TnJ0C5rk7Nvd",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "class MultilayerPerceptron2(nn.Module):\n",
        "  # whatever op layer is becomes num ip for next hidden layer\n",
        "  def __init__(self, input_size=45*54*45, output_size=58):\n",
        "    super().__init__()\n",
        "    N = 200\n",
        "    self.d1 = nn.Linear(input_size, N)\n",
        "    self.d2 = nn.Linear(N, N)\n",
        "    self.d3 = nn.Linear(N, N)\n",
        "    self.d4 = nn.Linear(N, N)\n",
        "    self.d5 = nn.Linear(N, output_size)\n",
        "    self.dropout = nn.Dropout(0.66)\n",
        "    self.flat = nn.Flatten()\n",
        "\n",
        "\n",
        "  def forward(self,X):\n",
        "\n",
        "    X = self.flat(X) #X.view(-1,45*54*45)\n",
        "    X = F.relu(self.d1(X))\n",
        "    X = F.relu(self.d3(X))\n",
        "    X = self.dropout(X)\n",
        "    X = F.relu(self.d4(X))\n",
        "    X = self.d5(X)\n",
        "    #X = torch.squeeze(X)\n",
        "\n",
        "    return X"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Download the mlp model weights from below link as provided in GitHub repo: https://github.com/AmmarPL/fMRI-Classification-JHU/ <br>\n",
        "Model Weights - https://livejohnshopkins-my.sharepoint.com/:f:/g/personal/apallik1_jh_edu/Eqo4DojG33pBquC3_zCNYhYBSTcUS6Ppfhl9OAVF_erlZQ?e=bfn0E7"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vw5SBCfh7NzE",
        "outputId": "c4a1696a-2109-4f78-e15f-08300250be3a",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Device is cpu\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "MultilayerPerceptron2(\n",
              "  (d1): Linear(in_features=109350, out_features=200, bias=True)\n",
              "  (d2): Linear(in_features=200, out_features=200, bias=True)\n",
              "  (d3): Linear(in_features=200, out_features=200, bias=True)\n",
              "  (d4): Linear(in_features=200, out_features=200, bias=True)\n",
              "  (d5): Linear(in_features=200, out_features=58, bias=True)\n",
              "  (dropout): Dropout(p=0.66, inplace=False)\n",
              "  (flat): Flatten(start_dim=1, end_dim=-1)\n",
              ")"
            ]
          },
          "execution_count": 5,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "device = 'cuda' if torch.cuda.is_available() else \"cpu\"\n",
        "print(f'Device is {device}')\n",
        "model_final =  torch.load('/content/drive/MyDrive/Colab Notebooks/model_mlp_final.pth', map_location=device)\n",
        "\n",
        "model = MultilayerPerceptron2()\n",
        "\n",
        "\n",
        "model.load_state_dict(model_final['model'])\n",
        "model = model.to(device)\n",
        "model.eval()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Below cell saves the mlp model as an ONNX file which we can then use with OnnxBridge."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "cKJt9Tmx7N2t",
        "vscode": {
          "languageId": "python"
        }
      },
      "outputs": [],
      "source": [
        "output_onnx = str(Path(\"mlp_model.onnx\"))\n",
        "input_volume = torch.randn(1, 1, 45, 54, 45)\n",
        "\n",
        "# Export an ONNX model.\n",
        "with torch.no_grad():\n",
        "    torch.onnx.export(\n",
        "        model=model,\n",
        "        args=(input_volume),\n",
        "        f=output_onnx,\n",
        "        opset_version=13,\n",
        "        verbose=True,\n",
        "        input_names=[\"image\"],\n",
        "        output_names=[ \"score\"],\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Successfully saved mlp_model.onnx !!"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Xonsh",
      "language": "xonsh",
      "name": "xonsh"
    },
    "language_info": {
      "name": "xonsh"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
